"""WebSocket适配器 - 支持双向通信的IM系统"""

import asyncio
import logging
import typing
from datetime import datetime

import pydantic

import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from ...core import app
from .websocket_manager import ws_connection_manager, WebSocketConnection

logger = logging.getLogger(__name__)


class WebSocketMessage(pydantic.BaseModel):
    """Wire format of a single chat message exchanged over WebSocket."""

    # Sequence number of the message within its history list.
    id: int
    # Author of the message: 'user' or 'assistant'.
    role: str
    # Plain-text rendering of the message.
    content: str
    # Structured message components, one dict per component.
    message_chain: list[dict]
    # Creation time as a string (the adapter in this module stores ISO-8601 text).
    timestamp: str
    # False while a streamed reply is still being produced.
    is_final: bool = False
    # ID of the connection that sent the message; empty when not recorded.
    connection_id: str = ''


class WebSocketSession:
    """In-memory message history for one WebSocket session.

    Histories are kept separately per pipeline, keyed by pipeline UUID.
    """

    # Session identifier supplied by whoever creates the session.
    id: str
    # Message history per pipeline: {pipeline_uuid: [messages]}.
    # Deliberately an annotation only: a class-level ``{}`` default would be a
    # single dict shared by every instance that failed to rebind it.
    message_lists: dict[str, list[WebSocketMessage]]

    def __init__(self, id: str):
        """Create an empty session.

        Args:
            id: Identifier of the session.
        """
        self.id = id
        self.message_lists = {}

    def get_message_list(self, pipeline_uuid: str) -> list[WebSocketMessage]:
        """Return the live history list for a pipeline, creating it if absent.

        Args:
            pipeline_uuid: UUID of the pipeline whose history is wanted.

        Returns:
            The stored list itself (not a copy), so appending to the returned
            list updates the session history.
        """
        return self.message_lists.setdefault(pipeline_uuid, [])


class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
    """WebSocket adapter - supports two-way real-time communication."""

    # Message histories for private ("person") and group chats. __init__ replaces
    # both with sessions that carry an explicit id.
    # NOTE(review): WebSocketSession.__init__ requires `id`, so
    # default_factory=WebSocketSession would fail if it were ever called with no
    # arguments - confirm the factory is never actually invoked.
    websocket_person_session: WebSocketSession = pydantic.Field(exclude=True, default_factory=WebSocketSession)
    websocket_group_session: WebSocketSession = pydantic.Field(exclude=True, default_factory=WebSocketSession)

    # Registered event handlers, one per event type (a later registration for the
    # same type overwrites the earlier one).
    # NOTE(review): the annotation says handlers return None, but
    # handle_websocket_message passes the call result to asyncio.create_task, so
    # handlers are expected to return a coroutine - confirm and align the type.
    listeners: dict[
        typing.Type[platform_events.Event],
        typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
    ] = pydantic.Field(default_factory=dict, exclude=True)

    # Application object; this class reads platform_mgr and storage_mgr from it.
    ap: app.Application = pydantic.Field(exclude=True)

    # Queue of messages the backend pushes proactively: send_message() enqueues
    # payloads and run_async() broadcasts them.
    outbound_message_queue: asyncio.Queue = pydantic.Field(default_factory=asyncio.Queue, exclude=True)
    """后端主动推送消息的队列"""

    def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
        """Build the adapter with empty person/group histories and a fresh push queue.

        Args:
            config: Adapter configuration, forwarded unchanged to the base class.
            logger: Event logger, forwarded unchanged to the base class.
            **kwargs: Remaining fields, forwarded unchanged to the base class.
        """
        super().__init__(config=config, logger=logger, **kwargs)

        # One history container per chat kind, each with a fixed identifier.
        self.websocket_person_session = WebSocketSession(id='websocketperson')
        self.websocket_group_session = WebSocketSession(id='websocketgroup')

        # Fixed account ID used for the bot side of this adapter.
        self.bot_account_id = 'websocketbot'

        # Payloads queued by send_message() wait here until run_async() sends them.
        self.outbound_message_queue = asyncio.Queue()

    async def send_message(
        self,
        target_type: str,
        target_id: str,
        message: platform_message.MessageChain,
    ) -> dict:
        """Queue a bot-initiated message for delivery to the front end.

        Nothing is sent here directly: the payload is placed on
        ``outbound_message_queue`` and broadcast later by ``run_async``, which
        uses ``target_id`` as the pipeline key.

        Args:
            target_type: Kind of target; copied into the payload as-is.
            target_id: Target identifier; copied into the payload as-is.
            message: Message chain to push.

        Returns:
            The payload dict that was queued.
        """
        text = str(message)
        chain = [component.__dict__ for component in message]

        payload = {
            'type': 'bot_message',
            'target_type': target_type,
            'target_id': target_id,
            'content': text,
            'message_chain': chain,
            'timestamp': datetime.now().isoformat(),
        }

        # Hand the payload to the push loop.
        await self.outbound_message_queue.put(payload)
        return payload

    async def reply_message(
        self,
        message_source: platform_events.MessageEvent,
        message: platform_message.MessageChain,
        quote_origin: bool = False,
    ) -> dict:
        """Send a complete (non-streamed) assistant reply.

        The reply is appended to the matching session history and broadcast
        through ``ws_connection_manager.broadcast_to_pipeline``.

        Args:
            message_source: Event being answered. Only its type is inspected:
                a ``GroupMessage`` selects the group session, anything else the
                person session.
            message: Reply content.
            quote_origin: Not used by this adapter.

        Returns:
            The stored reply as a plain dict.
        """
        if isinstance(message_source, platform_events.GroupMessage):
            session, session_type = self.websocket_group_session, 'group'
        else:
            session, session_type = self.websocket_person_session, 'person'

        # The pipeline is read from the shared proxy bot entity, not from message_source.
        pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
        history = session.get_message_list(pipeline_uuid)

        reply = WebSocketMessage(
            id=len(history) + 1,  # IDs are 1-based positions in the history
            role='assistant',
            content=str(message),
            message_chain=[component.__dict__ for component in message],
            timestamp=datetime.now().isoformat(),
            is_final=True,
        )
        history.append(reply)

        # Broadcast to the pipeline; session_type travels both in the payload and as a keyword.
        envelope = {
            'type': 'response',
            'session_type': session_type,
            'data': reply.model_dump(),
        }
        await ws_connection_manager.broadcast_to_pipeline(pipeline_uuid, envelope, session_type=session_type)

        return reply.model_dump()

    async def reply_message_chunk(
        self,
        message_source: platform_events.MessageEvent,
        bot_message,
        message: platform_message.MessageChain,
        quote_origin: bool = False,
        is_final: bool = False,
    ) -> dict:
        """Send one chunk of a streamed assistant reply.

        Every chunk is broadcast, but the history is only touched once the
        reply is complete, i.e. ``is_final`` is true and ``bot_message`` carries
        no tool calls.

        Args:
            message_source: Event being answered. Only its type is inspected:
                a ``GroupMessage`` selects the group session, anything else the
                person session.
            bot_message: Message object from the caller; only ``tool_calls`` is
                read, and only when ``is_final`` is true.
            message: Content accumulated for this chunk.
            quote_origin: Not used by this adapter.
            is_final: Whether the caller considers this the last chunk.

        Returns:
            The broadcast chunk as a plain dict.
        """
        if isinstance(message_source, platform_events.GroupMessage):
            session, session_type = self.websocket_group_session, 'group'
        else:
            session, session_type = self.websocket_person_session, 'person'

        pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
        history = session.get_message_list(pipeline_uuid)

        # A new message starts when there is no history yet or the newest entry is
        # already final; otherwise the chunk continues (and reuses the ID and
        # timestamp of) the newest entry.
        # NOTE(review): within this module entries are only ever stored with
        # is_final=True, so the "continue" path looks unreachable - confirm.
        starts_new_message = not history or history[-1].is_final
        completed = is_final and bot_message.tool_calls is None

        chunk = WebSocketMessage(
            id=len(history) + 1 if starts_new_message else history[-1].id,
            role='assistant',
            content=str(message),
            message_chain=[component.__dict__ for component in message],
            timestamp=datetime.now().isoformat() if starts_new_message else history[-1].timestamp,
            is_final=completed,
        )

        # Persist only completed replies: append a new one or replace the entry being continued.
        if completed:
            if starts_new_message:
                history.append(chunk)
            else:
                history[-1] = chunk

        envelope = {
            'type': 'response',
            'session_type': session_type,
            'data': chunk.model_dump(),
        }
        await ws_connection_manager.broadcast_to_pipeline(pipeline_uuid, envelope, session_type=session_type)

        return chunk.model_dump()

    async def is_stream_output_supported(self) -> bool:
        """Report that streamed output is supported; this adapter always returns True."""
        return True

    def register_listener(
        self,
        event_type: typing.Type[platform_events.Event],
        func: typing.Callable[
            [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
        ],
    ):
        """Register the handler for an event type.

        Only one handler is kept per event type: registering again for the same
        type replaces the previous handler.

        Args:
            event_type: Event class the handler should receive.
            func: Handler called with ``(event, adapter)``.
        """
        self.listeners[event_type] = func

    def unregister_listener(
        self,
        event_type: typing.Type[platform_events.Event],
        func: typing.Callable[
            [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
        ],
    ):
        """Remove the handler registered for an event type.

        Unregistering is idempotent: if nothing is registered for
        ``event_type`` the call is a no-op instead of raising ``KeyError``.

        Args:
            event_type: Event class whose handler should be removed.
            func: Accepted for interface compatibility but not consulted; the
                entry for ``event_type`` is removed whichever handler it holds.
        """
        # pop() with a default tolerates repeated or unmatched unregistration.
        self.listeners.pop(event_type, None)

    async def is_muted(self, group_id: int) -> bool:
        """Report whether the bot is muted in a group.

        This adapter always returns False; ``group_id`` is not inspected.
        """
        return False

    async def run_async(self):
        """Run the outbound push loop until the surrounding task is cancelled.

        Each payload queued by ``send_message`` is broadcast through
        ``ws_connection_manager.broadcast_to_pipeline``, using the payload's
        ``target_id`` (empty string when absent) as the pipeline key.

        The loop suspends on ``Queue.get()`` instead of polling, so an idle
        adapter does no periodic wake-ups and a queued message is delivered as
        soon as the event loop schedules this task.

        Raises:
            asyncio.CancelledError: Propagated unchanged when the task is
                cancelled, which is the only way the loop ends normally.
        """
        while True:
            # Suspends here until a payload is available; cancellation is
            # delivered at this await (or at the broadcast below).
            message = await self.outbound_message_queue.get()

            target_id = message.get('target_id', '')
            await ws_connection_manager.broadcast_to_pipeline(target_id, message)

    async def kill(self):
        """Stop the adapter. Currently a no-op: nothing is released here."""

    async def _process_image_components(self, message_chain_obj: list):
        """Inline stored image files into a message chain as base64 data URIs.

        For every component whose ``type`` is ``'Image'`` and whose ``path`` is
        non-empty, the file is loaded from the storage provider, written to the
        component's ``base64`` key as a ``data:`` URI, deleted from storage,
        and ``path`` is reset to an empty string. The list is modified in
        place. A failure on one component is logged and does not stop the rest.

        Args:
            message_chain_obj: List of message component dicts.
        """
        import base64

        storage_mgr = self.ap.storage_mgr

        # File suffix -> MIME type; anything unrecognised falls back to PNG.
        mime_by_suffix = (
            (('.jpg', '.jpeg'), 'image/jpeg'),
            (('.png',), 'image/png'),
            (('.gif',), 'image/gif'),
            (('.webp',), 'image/webp'),
        )

        for component in message_chain_obj:
            # Skip everything that is not an image with a storage path.
            if component.get('type') != 'Image' or not component.get('path'):
                continue

            try:
                file_content = await storage_mgr.storage_provider.load(component['path'])
                encoded = base64.b64encode(file_content).decode('utf-8')

                lowered_key = component['path'].lower()
                mime_type = next(
                    (mime for suffixes, mime in mime_by_suffix if lowered_key.endswith(suffixes)),
                    'image/png',
                )

                component['base64'] = f'data:{mime_type};base64,{encoded}'

                # The stored file is no longer needed once inlined, so it is removed
                # and the path cleared.
                await storage_mgr.storage_provider.delete(component['path'])
                component['path'] = ''
            except Exception as e:
                await self.logger.error(f'加载图片文件失败 {component["path"]}: {e}')

    async def handle_websocket_message(
        self,
        connection: WebSocketConnection,
        message_data: dict,
    ):
        """Accept one inbound WebSocket message and dispatch it to the pipeline.

        The message is stored in the session history, broadcast as a
        ``user_message``, wrapped in a platform event and handed to the
        registered listener in a background task. The method does not wait
        for any reply; replies are sent through ``reply_message`` /
        ``reply_message_chunk``.

        Args:
            connection: Connection the message arrived on; supplies
                ``pipeline_uuid``, ``session_type`` and ``connection_id``.
            message_data: Decoded message; its ``'message'`` entry (default
                ``[]``) is the list of message component dicts.
        """
        pipeline_uuid = connection.pipeline_uuid
        session_type = connection.session_type

        # 'group' selects the group history; every other value uses the person history.
        if session_type == 'group':
            session = self.websocket_group_session
        else:
            session = self.websocket_person_session

        # Inline image files first so both the validated chain and the stored
        # copy carry the base64 data.
        raw_chain = message_data.get('message', [])
        await self._process_image_components(raw_chain)
        message_chain = platform_message.MessageChain.model_validate(raw_chain)

        history = session.get_message_list(pipeline_uuid)
        message_id = len(history) + 1  # IDs are 1-based positions in the history

        # User messages are always complete (never streamed).
        user_message = WebSocketMessage(
            id=message_id,
            role='user',
            content=str(message_chain),
            message_chain=raw_chain,
            timestamp=datetime.now().isoformat(),
            connection_id=connection.connection_id,
            is_final=True,
        )
        history.append(user_message)

        # Echo the user message to the pipeline's connections.
        await ws_connection_manager.broadcast_to_pipeline(
            pipeline_uuid,
            {
                'type': 'user_message',
                'session_type': session_type,
                'data': user_message.model_dump(),
            },
            session_type=session_type,
        )

        # Prepend the source marker expected at the head of the chain.
        message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp()))

        # Build the event. Only 'person' yields a friend message; any other
        # session type is treated as a group message.
        sender_id = f'websocket_{connection.connection_id}'
        if session_type == 'person':
            sender = platform_entities.Friend(id=sender_id, nickname='User', remark='User')
            event = platform_events.FriendMessage(
                sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
            )
        else:
            group = platform_entities.Group(
                id='websocketgroup', name='Group', permission=platform_entities.Permission.Member
            )
            sender = platform_entities.GroupMember(
                id=sender_id,
                member_name='User',
                group=group,
                permission=platform_entities.Permission.Member,
            )
            event = platform_events.GroupMessage(
                sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
            )

        # Record which pipeline the replies belong to.
        # NOTE(review): this value lives on the shared proxy bot entity, so
        # concurrent messages for different pipelines may overwrite each other
        # before the reply methods read it - confirm.
        self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid

        event_cls = event.__class__
        if event_cls not in self.listeners:
            return

        # Fire-and-forget dispatch.
        # NOTE(review): no reference to the task is kept; the asyncio docs advise
        # holding one so the task cannot be garbage-collected mid-run - confirm.
        asyncio.create_task(self.listeners[event_cls](event, self))

    def get_websocket_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]:
        """Return the stored history of a pipeline as plain dicts.

        Args:
            pipeline_uuid: UUID of the pipeline whose history is wanted.
            session_type: ``'person'`` selects the person session; any other
                value selects the group session.

        Returns:
            One dict per stored message, in stored order.
        """
        if session_type == 'person':
            session = self.websocket_person_session
        else:
            session = self.websocket_group_session

        return [entry.model_dump() for entry in session.get_message_list(pipeline_uuid)]

    def reset_session(self, pipeline_uuid: str, session_type: str):
        """Clear the stored history of a pipeline.

        Only an existing history is replaced with a new empty list; a pipeline
        with no history yet is left untouched.

        Args:
            pipeline_uuid: UUID of the pipeline whose history should be cleared.
            session_type: ``'person'`` selects the person session; any other
                value selects the group session.
        """
        if session_type == 'person':
            session = self.websocket_person_session
        else:
            session = self.websocket_group_session

        if pipeline_uuid in session.message_lists:
            session.message_lists[pipeline_uuid] = []
